{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "af841f26-7fe5-4078-a1ff-060dbf5b93a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import sklearn\n",
    "from pol.datasets.linear_regression import create_nd_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "512f1340-c3db-439c-b698-9a9d157b7482",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 1.9068, -0.4792,  1.6930, -0.9532, -0.7236, -1.5276, -1.0329, -0.7259,\n",
      "         1.8563, -0.9454, -0.2360,  0.4395,  1.4545,  1.4550,  0.6995,  0.6395,\n",
      "         0.9430, -1.1090, -1.3117,  1.4817], dtype=torch.float64)\n",
      "X shape: torch.Size([20, 20]), Y shape: torch.Size([20, 1])\n"
     ]
    }
   ],
   "source": [
    "theta = 0.839\n",
    "dataset = create_nd_dataset(dim=20, num_sample=20, seed=1234)\n",
    "X, Y = dataset[:]\n",
    "print(f'X shape: {X.shape}, Y shape: {Y.shape}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "45aa23b6-3d81-4c6e-bc4a-5818737af9c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_lasso(X, Y, theta):\n",
    "    from sklearn import linear_model\n",
    "    print('lasso alpha={}'.format(theta))\n",
    "    clf = linear_model.Lasso(\n",
    "        alpha=theta / 2,\n",
    "        fit_intercept=False,\n",
    "        max_iter=100)\n",
    "    clf.fit(X, Y)\n",
    "    W = clf.coef_\n",
    "    return W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9af25ec7-d3b4-42a1-ae35-f9280b396a2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_loss(W):\n",
    "    W = W.unsqueeze(-1) # Nx1\n",
    "    l1 = (torch.matmul(X, W) - Y).square().squeeze(-1)\n",
    "    l2 = theta * W.abs().squeeze(-1)\n",
    "    return (l1 + l2).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "52bec643-8aba-4822-9a7f-5aad4956fa44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lasso alpha=0.839\n",
      "W_sk: tensor([ 1.0047, -0.0000,  0.5792, -0.0000, -0.0000, -1.1364, -0.5494, -0.0000,\n",
      "         0.3137,  0.0000,  0.2430,  0.0000,  0.0610,  0.4075, -0.0000,  0.7827,\n",
      "         0.5565,  0.0000, -0.0000,  1.6964])\n",
      "L_sk: 3.734462022781372\n"
     ]
    }
   ],
   "source": [
    "W_sk = solve_lasso(X.numpy(), Y.squeeze(-1).numpy(), theta=theta)\n",
    "W_sk = torch.from_numpy(W_sk)\n",
    "L_sk = compute_loss(W_sk)\n",
    "print(f'W_sk: {W_sk}')\n",
    "print(f'L_sk: {L_sk}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9c374b2b-33db-4a4e-9543-df2eb9334448",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W_pol: tensor([-3.0000e-04, -4.0000e-04,  8.3600e-02,  1.9500e-02,  4.0000e-04,\n",
      "        -2.0000e-03, -1.4000e-03, -1.2000e-03,  2.1810e-01, -4.0000e-04,\n",
      "        -1.0100e-02, -0.0000e+00, -4.2100e-02,  1.9500e-02, -1.6940e+00,\n",
      "         2.4000e-03, -6.0000e-04, -5.0000e-04,  1.0000e-03, -1.8000e-03])\n",
      "L_pol: 22.54592514038086\n"
     ]
    }
   ],
   "source": [
    "W_pol = torch.tensor([-0.0003,-0.0004,0.0836,0.0195,0.0004,-0.0020,-0.0014,-0.0012,0.2181,-0.0004,-0.0101,-0.0000,-0.0421,0.0195,-1.6940,0.0024,-0.0006,-0.0005,0.0010,-0.0018])\n",
    "L_pol = compute_loss(W_pol)\n",
    "print(f'W_pol: {W_pol}')\n",
    "print(f'L_pol: {L_pol}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b796c1be-39d3-4549-9135-006788d9d232",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b09f61f1-659c-4064-8d08-e44bd4222863",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
